{
 "cells": [
  {
   "cell_type": "raw",
   "id": "bcd88fb5-6a0f-4c4b-81fd-34be59ea7903",
   "metadata": {},
   "source": [
    "模型下载链接：https://hf-mirror.com/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth?download=true"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5b78b7ef-acc6-46cf-88c2-f90a2835e4b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "########################################################################################################\n",
    "# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM\n",
    "########################################################################################################\n",
    "\n",
    "import numpy as np\n",
    "np.set_printoptions(precision=4, suppress=True, linewidth=200)\n",
    "import types, torch\n",
    "from torch.nn import functional as F\n",
    "from tokenizers import Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "deacc22b-2896-4b77-b595-3284b0c13544",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = Tokenizer.from_file(\"20B_tokenizer.json\")\n",
    "\n",
    "args = types.SimpleNamespace()\n",
    "args.MODEL_NAME = '/data1/ckw/RWKV-4-Pile-430M-20220808-8066'\n",
    "args.n_layer = 24\n",
    "args.n_embd = 1024\n",
    "\n",
    "context = \"\\nDataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence.\"\n",
    "NUM_TRIALS = 3\n",
    "LENGTH_PER_TRIAL = 100\n",
    "TEMPERATURE = 1.0\n",
    "TOP_P = 0.85\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c85bca7-1342-4d8b-870c-baddcf2661d6",
   "metadata": {},
   "source": [
    "### RWKV 的时间混合实现\n",
    "\n",
    "在 RWKV 模型中，时间混合（Time Mixing）是一个关键步骤，用于处理输入序列随时间的变化。以下是 `time_mixing` 函数的详细公式说明和代码注释。\n",
    "\n",
    "#### 公式说明\n",
    "\n",
    "时间混合的核心思想是通过时间混合系数将当前输入与先前的状态混合，以生成新的键、值和门控信号。这一过程涉及如下步骤：\n",
    "\n",
    "1. **混合输入**：\n",
    "   - 对当前输入 \\( x \\) 和前一状态进行加权平均：\n",
    "     $$ x_k = x \\cdot \\text{time\\_mix\\_k} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_k}) $$\n",
    "     $$ x_v = x \\cdot \\text{time\\_mix\\_v} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_v}) $$\n",
    "     $$ x_r = x \\cdot \\text{time\\_mix\\_r} + \\text{state}[5i+1] \\cdot (1 - \\text{time\\_mix\\_r}) $$\n",
    "\n",
    "2. **状态更新**：\n",
    "   - 更新状态：\n",
    "     $$ \\text{state}[5i+1] = x $$\n",
    "\n",
    "3. **计算门控信号**：\n",
    "   - 使用 sigmoid 激活函数计算门控信号 \\( r \\)：\n",
    "     $$ r = \\sigma(\\text{rw} @ x_r) $$\n",
    "\n",
    "4. **计算键和值**：\n",
    "   - 通过线性变换生成键 \\( k \\) 和值 \\( v \\)：\n",
    "     $$ k = \\text{kw} @ x_k $$\n",
    "     $$ v = \\text{vw} @ x_v $$\n",
    "\n",
    "5. **加权和计算**：\n",
    "   - 根据加权和公式计算加权和 \\( wkv \\)：\n",
    "     $$ a = e1 \\cdot aa + e2 \\cdot v $$\n",
    "     $$ b = e1 \\cdot bb + e2 $$\n",
    "     $$ \\text{wkv} = a / b $$\n",
    "\n",
    "代码如下：\n",
    "\n",
    "```python\n",
    "@torch.jit.script_method\n",
    "def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
    "    # 混合当前输入和先前的状态\n",
    "    xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
    "    xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
    "    xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
    "\n",
    "    # 更新状态\n",
    "    state[5*i+1] = x\n",
    "\n",
    "    # 计算门控信号\n",
    "    r = torch.sigmoid(rw @ xr)\n",
    "    \n",
    "    # 计算键和值\n",
    "    k = kw @ xk\n",
    "    v = vw @ xv\n",
    "\n",
    "    # 从状态中读取先前的累积值\n",
    "    aa = state[5*i+2]\n",
    "    bb = state[5*i+3]\n",
    "    pp = state[5*i+4]\n",
    "\n",
    "    # 计算加权和的第一部分\n",
    "    ww = time_first + k\n",
    "    qq = torch.maximum(pp, ww)\n",
    "    e1 = torch.exp(pp - qq)\n",
    "    e2 = torch.exp(ww - qq)\n",
    "    a = e1 * aa + e2 * v\n",
    "    b = e1 * bb + e2\n",
    "    wkv = a / b\n",
    "\n",
    "    # 计算新的权重和状态\n",
    "    ww = pp + time_decay\n",
    "    qq = torch.maximum(ww, k)\n",
    "    e1 = torch.exp(ww - qq)\n",
    "    e2 = torch.exp(k - qq)\n",
    "    state[5*i+2] = e1 * aa + e2 * v\n",
    "    state[5*i+3] = e1 * bb + e2\n",
    "    state[5*i+4] = qq\n",
    "\n",
    "    # 计算最终的输出\n",
    "    return ow @ (r * wkv)\n",
    "```\n",
    "\n",
    "### 详细解释\n",
    "\n",
    "1. **混合输入**：\n",
    "   - `xk`, `xv`, `xr` 是输入 `x` 与状态 `state` 的加权混合，分别用于计算键、值和门控信号。\n",
    "\n",
    "2. **状态更新**：\n",
    "   - 将当前输入 `x` 存储在状态数组中，供下一步计算使用。\n",
    "\n",
    "3. **计算门控信号**：\n",
    "   - 使用 `torch.sigmoid` 计算门控信号 `r`，它决定了多少信息将被传递。\n",
    "\n",
    "4. **计算键和值**：\n",
    "   - 使用矩阵乘法计算键 `k` 和值 `v`。\n",
    "\n",
    "5. **加权和计算**：\n",
    "   - 通过指数加权平均计算加权和 `wkv`，这涉及处理数值稳定性问题（通过 `torch.maximum` 和指数运算）。\n",
    "\n",
    "6. **更新状态**：\n",
    "   - 更新状态数组中的累积值，以便后续时间步使用。\n",
    "\n",
    "7. **计算最终输出**：\n",
    "   - 使用门控信号 `r` 和加权和 `wkv` 计算最终输出。\n",
    "\n",
    "这样，通过逐步混合当前输入和先前的状态，RWKV 模型实现了时间序列数据的有效处理。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f0d47e4-1792-47e4-a506-3071f510526e",
   "metadata": {},
   "source": [
    "### RWKV 的通道混合（Channel Mixing）实现与代码注释\n",
    "\n",
    "在 RWKV 模型中，通道混合（Channel Mixing）是另一个关键步骤，用于处理不同通道之间的信息交换。以下是 `channel_mixing` 函数的详细公式说明和代码注释。\n",
    "\n",
    "#### 公式说明\n",
    "\n",
    "通道混合的核心思想是通过通道混合系数将当前输入与先前的状态混合，以生成新的键和门控信号。这一过程涉及如下步骤：\n",
    "\n",
    "1. **混合输入**：\n",
    "   - 对当前输入 \\( x \\) 和前一状态进行加权平均：\n",
    "     $$ x_k = x \\cdot \\text{time\\_mix\\_k} + \\text{state}[5i+0] \\cdot (1 - \\text{time\\_mix\\_k}) $$\n",
    "     $$ x_r = x \\cdot \\text{time\\_mix\\_r} + \\text{state}[5i+0] \\cdot (1 - \\text{time\\_mix\\_r}) $$\n",
    "\n",
    "2. **状态更新**：\n",
    "   - 更新状态：\n",
    "     $$ \\text{state}[5i+0] = x $$\n",
    "\n",
    "3. **计算门控信号**：\n",
    "   - 使用 sigmoid 激活函数计算门控信号 \\( r \\)：\n",
    "     $$ r = \\sigma(\\text{rw} @ x_r) $$\n",
    "\n",
    "4. **计算键**：\n",
    "   - 通过 ReLU 和平方变换生成键 \\( k \\)：\n",
    "     $$ k = (\\text{ReLU}(\\text{kw} @ x_k))^2 $$\n",
    "\n",
    "5. **计算输出**：\n",
    "   - 使用门控信号和键计算最终的输出：\n",
    "     $$ \\text{output} = r \\cdot (\\text{vw} @ k) $$\n",
    "\n",
    "代码如下：\n",
    "\n",
    "```python\n",
    "@torch.jit.script_method\n",
    "def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "    # 混合当前输入和先前的状态\n",
    "    xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
    "    xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
    "\n",
    "    # 更新状态\n",
    "    state[5*i+0] = x\n",
    "\n",
    "    # 计算门控信号\n",
    "    r = torch.sigmoid(rw @ xr)\n",
    "\n",
    "    # 计算键，并通过ReLU和平方变换\n",
    "    k = torch.square(torch.relu(kw @ xk))  # square relu, primer paper\n",
    "\n",
    "    # 计算最终的输出\n",
    "    return r * (vw @ k)\n",
    "```\n",
    "\n",
    "\n",
    "1. **混合输入**：\n",
    "   - `xk`, `xr` 是输入 `x` 与状态 `state` 的加权混合，分别用于计算键和门控信号。\n",
    "\n",
    "2. **状态更新**：\n",
    "   - 将当前输入 `x` 存储在状态数组中，供下一步计算使用。\n",
    "\n",
    "3. **计算门控信号**：\n",
    "   - 使用 `torch.sigmoid` 计算门控信号 `r`，它决定了多少信息将被传递。\n",
    "\n",
    "4. **计算键**：\n",
    "   - 使用 `torch.relu` 计算键 `k`，然后进行平方变换以增加非线性特性。\n",
    "\n",
    "5. **计算最终输出**：\n",
    "   - 使用门控信号 `r` 和键 `k` 计算最终输出。\n",
    "\n",
    "通过这些步骤，RWKV 模型实现了通道间的信息有效交换，增强了模型对输入数据的处理能力。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0f1b2e2b-9f0d-4db3-b9d9-d43e3e2537ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RWKV_RNN(torch.jit.ScriptModule):\n",
    "    def __init__(self, args):\n",
    "        super().__init__()\n",
    "        self.args = args\n",
    "        self.eval() # set torch to inference mode\n",
    "        \n",
    "        w = torch.load(args.MODEL_NAME + '.pth', map_location='cpu')\n",
    "        for k in w.keys():\n",
    "            if      '.time_' in k: w[k] = w[k].squeeze()\n",
    "            if '.time_decay' in k: w[k] = -torch.exp(w[k].float()) # the real time decay is like e^{-e^x}\n",
    "            else: w[k] = w[k].float() # convert to f32 type\n",
    "        \n",
    "        self.w = types.SimpleNamespace() # set self.w from w\n",
    "        self.w.blocks = {}\n",
    "        for k in w.keys(): # example: \"blocks.0.att.time_first\" => self.w.blocks[0].att.time_first\n",
    "            parts = k.split('.')\n",
    "            last = parts.pop()\n",
    "            here = self.w\n",
    "            for p in parts:\n",
    "                if p.isdigit():\n",
    "                    p = int(p)\n",
    "                    if p not in here: here[p] = types.SimpleNamespace()\n",
    "                    here = here[p]\n",
    "                else:\n",
    "                    if not hasattr(here, p): setattr(here, p, types.SimpleNamespace())\n",
    "                    here = getattr(here, p)\n",
    "            setattr(here, last, w[k])\n",
    "\n",
    "    def layer_norm(self, x, w):\n",
    "        return F.layer_norm(x, (self.args.n_embd,), weight=w.weight, bias=w.bias)\n",
    "\n",
    "    @torch.jit.script_method\n",
    "    def channel_mixing(self, x, state, i:int, time_mix_k, time_mix_r, kw, vw, rw):\n",
    "        xk = x * time_mix_k + state[5*i+0] * (1 - time_mix_k)\n",
    "        xr = x * time_mix_r + state[5*i+0] * (1 - time_mix_r)\n",
    "        state[5*i+0] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = torch.square(torch.relu(kw @ xk)) # square relu, primer paper\n",
    "        return r * (vw @ k)\n",
    "\n",
    "    @torch.jit.script_method\n",
    "    def time_mixing(self, x, state, i:int, time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, kw, vw, rw, ow):\n",
    "        xk = x * time_mix_k + state[5*i+1] * (1 - time_mix_k)\n",
    "        xv = x * time_mix_v + state[5*i+1] * (1 - time_mix_v)\n",
    "        xr = x * time_mix_r + state[5*i+1] * (1 - time_mix_r)\n",
    "        state[5*i+1] = x\n",
    "        r = torch.sigmoid(rw @ xr)\n",
    "        k = kw @ xk\n",
    "        v = vw @ xv\n",
    "        \n",
    "        aa = state[5*i+2]\n",
    "        bb = state[5*i+3]\n",
    "        pp = state[5*i+4]\n",
    "        ww = time_first + k\n",
    "        qq = torch.maximum(pp, ww)\n",
    "        e1 = torch.exp(pp - qq)\n",
    "        e2 = torch.exp(ww - qq)\n",
    "        a = e1 * aa + e2 * v\n",
    "        b = e1 * bb + e2\n",
    "        wkv = a / b\n",
    "        ww = pp + time_decay\n",
    "        qq = torch.maximum(ww, k)\n",
    "        e1 = torch.exp(ww - qq)\n",
    "        e2 = torch.exp(k - qq)\n",
    "        state[5*i+2] = e1 * aa + e2 * v\n",
    "        state[5*i+3] = e1 * bb + e2\n",
    "        state[5*i+4] = qq\n",
    "        return ow @ (r * wkv)\n",
    "\n",
    "    def forward(self, token, state):\n",
    "        with torch.no_grad():\n",
    "            if state == None:\n",
    "                state = torch.zeros(self.args.n_layer * 5, self.args.n_embd)\n",
    "                for i in range(self.args.n_layer): state[5*i+4] = -1e30 # -infinity\n",
    "            \n",
    "            x = self.w.emb.weight[token]\n",
    "            x = self.layer_norm(x, self.w.blocks[0].ln0)\n",
    "            for i in range(self.args.n_layer):\n",
    "                att = self.w.blocks[i].att\n",
    "                x = x + self.time_mixing(self.layer_norm(x, self.w.blocks[i].ln1), state, i, \n",
    "                    att.time_mix_k, att.time_mix_v, att.time_mix_r, att.time_first, att.time_decay, \n",
    "                    att.key.weight, att.value.weight, att.receptance.weight, att.output.weight)\n",
    "                ffn = self.w.blocks[i].ffn\n",
    "                x = x + self.channel_mixing(self.layer_norm(x, self.w.blocks[i].ln2), state, i, \n",
    "                    ffn.time_mix_k, ffn.time_mix_r, \n",
    "                    ffn.key.weight, ffn.value.weight, ffn.receptance.weight)\n",
    "            \n",
    "            x = self.w.head.weight @ self.layer_norm(x, self.w.ln_out)\n",
    "            return x.float(), state\n",
    "\n",
    "##########################################################################################################"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1b457af-77a3-4b5e-a6f3-034b0fc6708d",
   "metadata": {},
   "source": [
    "采样方法和v2、v3版本相比没有发生变化，代码做了一点优化调整而已。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fdf027b6-7df9-4c0f-818e-013e7c49e3cd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_logits(out, temperature=1.0, top_p=0.8):\n",
    "    probs = F.softmax(out, dim=-1).numpy()\n",
    "    sorted_probs = np.sort(probs)[::-1]\n",
    "    cumulative_probs = np.cumsum(sorted_probs)\n",
    "    cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])\n",
    "    probs[probs < cutoff] = 0\n",
    "    if temperature != 1.0:\n",
    "        probs = probs.pow(1.0 / temperature)\n",
    "    probs = probs / np.sum(probs)\n",
    "    out = np.random.choice(a=len(probs), p=probs)\n",
    "    return out\n",
    "\n",
    "########################################################################################################"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "298dbbde-6535-406b-bd43-f2d886799f8c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Using CPU. Loading /data1/ckw/RWKV-4-Pile-430M-20220808-8066 ...\n"
     ]
    }
   ],
   "source": [
    "print(f'\\nUsing CPU. Loading {args.MODEL_NAME} ...')\n",
    "model = RWKV_RNN(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7d366a89-02cb-4b5e-95ef-52f6376d3607",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Preprocessing context (slow version. see v2/rwkv/model.py for fast version)\n"
     ]
    }
   ],
   "source": [
    "print(f'\\nPreprocessing context (slow version. see v2/rwkv/model.py for fast version)')\n",
    "init_state = None\n",
    "for token in tokenizer.encode(context).ids:\n",
    "    init_out, init_state = model.forward(token, init_state)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5273e7a8-875e-4998-b98e-f81951a7af32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "--[ Trial 0 ]----------------- \n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The machine learning solutions applied to the class are called Persona, which consist of several categories:\n",
      "\n",
      "\\begin{tabular}{|c|c|c|}\n",
      "\\hline\n",
      "  Name   & Description  \\\\ \\hline\n",
      "\\hline\n",
      "  \\end{tabular}\n",
      "\n",
      "DataWhalechina organizes the data in two ways:\n",
      "\n",
      "\\begin{tabular}{|c|c|c|}\n",
      "\\hline\n",
      "  \\multicolumn{2}{|c}{\n",
      "\n",
      "--[ Trial 1 ]----------------- \n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The main goal is to allow learners to learn how to use artificial intelligence in an integrated fashion, by using both AI and deep learning techniques. Datawhalechina aims to teach AI algorithms from scratch and teach them from scratch to become competent with many algorithms that humans could not have.\n",
      "\n",
      "Applications\n",
      "\n",
      "Projects \n",
      " DeeplearningAI : Encourage AI algorithms to become competent with many algorithms that humans could not have. Datawhalechina aims to be able to combine knowledge from multiple AI\n",
      "\n",
      "--[ Trial 2 ]----------------- \n",
      "DataWhalechina is an organization founded at Shanghai Jiao Tong University that helps learners learn artificial intelligence. The company was founded in 2016. The company has graduated 1,000 engineers, who work from the companies headquarters in Shanghai.\n",
      "\n",
      "In September 2019, the team was reported to have learned over 400,000 artificial intelligence.\n",
      "\n",
      "In August 2019, the company was reported to have sold 600,000 artificial intelligence to clients in Singapore.\n",
      "\n",
      "References\n",
      "\n",
      "External links\n",
      " \n",
      "\n",
      "Category:Human machine interaction\n",
      "Category:Learning management systems\n",
      "Category:Learning management systemsTechnologies, industry,\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for TRIAL in range(NUM_TRIALS):\n",
    "    print(f'\\n\\n--[ Trial {TRIAL} ]-----------------', context, end=\"\")\n",
    "    all_tokens = []\n",
    "    out_last = 0\n",
    "    out, state = init_out.clone(), init_state.clone()\n",
    "    for i in range(LENGTH_PER_TRIAL):\n",
    "        token = sample_logits(out, TEMPERATURE, TOP_P)\n",
    "        all_tokens += [token]\n",
    "        tmp = tokenizer.decode(all_tokens[out_last:])\n",
    "        if '\\ufffd' not in tmp: # only print when we have a valid utf-8 string\n",
    "            print(tmp, end=\"\", flush=True)\n",
    "            out_last = i + 1\n",
    "        out, state = model.forward(token, state)       \n",
    "print('\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1cdc809-c64d-4861-b540-460bc1097e38",
   "metadata": {},
   "source": [
    "### 备注：RWKV 的Scaling Law（缩放定律）\n",
    "\n",
    "RWKV 的缩放定律描述了模型性能随着各种因素变化的数学关系。这些因素包括模型大小（$N$）、数据集大小（$D$）或最优计算预算（$C_{\\min}$）。缩放定律的重要性体现在以下两个方面：\n",
    "1. **预测与规划**：它们允许我们在训练大型模型之前，通过插值和外推来预测和规划成本和性能。\n",
    "2. **反馈与研究**：它们提供了关于模型失效情况下的重要反馈，指引未来研究方向。\n",
    "\n",
    "#### 关键内容总结：\n",
    "- **与之前的RNN研究对比**：之前的工作指出，LSTM不完全遵循与Transformer相同的对数线性缩放定律。然而，RWKV模型的训练结果表明，RWKV遵循与Transformer相同的一般缩放定律形式。\n",
    "- **实验验证**：在[v4的论文](https://arxiv.org/abs/2305.13048)通过训练45个RWKV模型，验证了其损失与计算量之间的线性关系，线性拟合的 $r^2$ 值为0.994，即使外推一个数量级，拟合度仍然很好（$r^2$为0.875）。\n",
    "\n",
    "这些结果显示了RWKV模型在缩放时的优越性和与Transformer相似的性能缩放行为。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f6025d-faea-4647-be05-8fb4cce05991",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
